
from torch import nn

class GTU(nn.Module):
    """Elementwise tanh-times-sigmoid activation.

    The output is ``tanh(x) * sigmoid(x)``, computed elementwise, so the
    result has exactly the same shape as the input.

    NOTE(review): both the tanh branch and the sigmoid (gate) branch are fed
    the *same* unsplit ``x``. Gated tanh units are often implemented by
    splitting the input into two halves (one for the tanh path, one for the
    gate); confirm that applying both to the full input is intended here.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registered as submodules so they appear in repr / named_modules;
        # neither holds learnable parameters.
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``tanh(x) * sigmoid(x)`` for input tensor ``x``."""
        return self.tanh(x) * self.sigmoid(x)
        